{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "配置 GPU 环境：通过 CUDA_VISIBLE_DEVICES 指定本进程可见的显卡"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Expose only the GPUs with indices 0, 1, 5 and 6 to this process.\n",
    "# NOTE(review): this runs before torch is imported in the cells below; keep it as the first code cell.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0,1,5,6\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "导入模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import LoraConfig, TaskType, get_peft_model\n",
    "from transformers import AutoModel, HfArgumentParser, TrainingArguments\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the causal-LM checkpoint from a local directory with bfloat16 weights and device_map=\"auto\".\n",
    "# NOTE(review): the path is an absolute local path; consider a configurable MODEL_DIR constant.\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "        \"/jydata/qwen/Qwen2-7B-Instruct\",device_map=\"auto\",torch_dtype=torch.bfloat16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the dtype reported by the loaded model (bfloat16 was requested at load time).\n",
    "model.dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check whether the input embedding matrix and the LM head hold the same values (tied weights).\n",
    "# get_input_embeddings() / get_output_embeddings() are the architecture-independent accessors;\n",
    "# the causal-LM wrapper exposes no `word_embeddings` attribute (its parameters are named\n",
    "# model.embed_tokens / lm_head). Both tensors are copied to the CPU first because\n",
    "# device_map=\"auto\" may have placed them on different devices, which torch.allclose rejects.\n",
    "embed_weight = model.get_input_embeddings().weight.data.cpu()\n",
    "head_weight = model.get_output_embeddings().weight.data.cpu()\n",
    "embed_weight.shape == head_weight.shape and torch.allclose(embed_weight, head_weight)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print every parameter name and collect the token-embedding and LM-head weight tensors.\n",
    "# Names are matched by suffix so the lookup works on the bare model (\"model.embed_tokens.weight\")\n",
    "# as well as on a PEFT-wrapped one (\"base_model.model.model.embed_tokens.weight\"); matching the\n",
    "# wrapped names exactly left `output` empty when this cell ran before get_peft_model().\n",
    "TRACKED_SUFFIXES = (\"embed_tokens.weight\", \"lm_head.weight\")\n",
    "output = []\n",
    "for name, param in model.named_parameters():\n",
    "    print(name)\n",
    "    if name.endswith(TRACKED_SUFFIXES):\n",
    "        output.append(param)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "分析加入 LoRA 之前的模型权重"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List every parameter name together with its shape; this cell precedes the get_peft_model() call.\n",
    "for name, param in model.named_parameters():\n",
    "        print(name, param.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "加入 LoRA 之后的模型权重"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# LoRA adapter configuration for causal-LM fine-tuning; adapters are attached to the modules\n",
    "# whose names are listed in target_modules.\n",
    "config = LoraConfig(\n",
    "    task_type=TaskType.CAUSAL_LM, \n",
    "    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
    "    inference_mode=False, # training mode\n",
    "    r=8, # LoRA rank\n",
    "    lora_alpha=32, # LoRA alpha; see the LoRA paper for its exact role\n",
    "    lora_dropout=0.1# dropout ratio\n",
    ")\n",
    "# NOTE(review): rebinding `model` makes this cell non-idempotent; re-running it wraps the wrapped model again.\n",
    "model = get_peft_model(model, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cal_param(param):\n",
    "    \"\"\"Return the number of elements in ``param``.\n",
    "\n",
    "    Args:\n",
    "        param: any object exposing a ``shape`` iterable of integer dimensions.\n",
    "\n",
    "    Returns:\n",
    "        The product of all dimensions; 1 when ``shape`` is empty.\n",
    "    \"\"\"\n",
    "    import math  # local import keeps the cell self-contained\n",
    "\n",
    "    return math.prod(param.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tally the element count of every parameter and, separately, of those whose name contains \"lora\".\n",
    "all_parameters = 0\n",
    "lora_parameters = 0\n",
    "for name, param in model.named_parameters():\n",
    "    size = cal_param(param)\n",
    "    all_parameters += size\n",
    "    lora_parameters += size if \"lora\" in name else 0\n",
    "\n",
    "# Sizes are reported in billions of elements and the share as a percentage, rounded to two decimals.\n",
    "lora_in_billions = round(lora_parameters / 10**9, 2)\n",
    "print(f\"lora参数量：{lora_in_billions}B\")\n",
    "all_in_billions = round(all_parameters / 10**9, 2)\n",
    "print(f\"所有参数量：{all_in_billions}B\")\n",
    "lora_share = round(lora_parameters / all_parameters * 100, 2)\n",
    "print(f\"lora参数占比：{lora_share}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "读取tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this import rebinds AutoModelForCausalLM / AutoTokenizer, which an earlier cell imported\n",
    "# from transformers, to the modelscope versions; confirm that is intended.\n",
    "from modelscope import AutoModelForCausalLM, AutoTokenizer\n",
    "# Load the tokenizer from the same local checkpoint directory the model was loaded from.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"/jydata/qwen/Qwen2-7B-Instruct\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "查看模型中可训练参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def print_trainable_parameters(model):\n",
    "    \"\"\"Print how many of ``model``'s parameters are trainable.\n",
    "\n",
    "    Walks ``model.named_parameters()``, sums ``numel()`` over all parameters and over those\n",
    "    with ``requires_grad`` set, then prints both totals and the trainable percentage.\n",
    "    \"\"\"\n",
    "    sizes = [(p.numel(), p.requires_grad) for _, p in model.named_parameters()]\n",
    "    all_param = sum(count for count, _ in sizes)\n",
    "    trainable_params = sum(count for count, trainable in sizes if trainable)\n",
    "    percentage = 100 * trainable_params / all_param\n",
    "    print(\n",
    "        f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {percentage}\"\n",
    "    )\n",
    "# Report the trainable share of the current `model`.\n",
    "print_trainable_parameters(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import json\n",
    "from itertools import islice\n",
    "\n",
    "SOURCE_JSONL = \"./v1/1.jsonl\"\n",
    "SUBSET_JSONL = \"./v1/1_modify_small.jsonl\"\n",
    "SUBSET_SIZE = 10\n",
    "\n",
    "# Copy the first SUBSET_SIZE records of the raw JSON Lines file into a small file for quick runs.\n",
    "# The output is opened once in \"w\" mode, so re-running the cell rewrites the subset instead of\n",
    "# appending another copy of the same records (the file used to be reopened in \"a\" mode per line).\n",
    "with open(SOURCE_JSONL, 'r', encoding='utf-8') as src:\n",
    "    with open(SUBSET_JSONL, 'w', encoding='utf-8') as dst:\n",
    "        for line in islice(src, SUBSET_SIZE):\n",
    "            # Parse the JSON text of the line; only element 0 of the parsed value is kept.\n",
    "            entry = json.loads(line)\n",
    "            dst.write(json.dumps(entry[0], ensure_ascii=False))\n",
    "            dst.write('\\n')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset,load_from_disk\n",
    "# Load the small JSON Lines subset with the 'json' loader; later cells access its \"train\" split.\n",
    "data = load_dataset('json',data_files=\"./v1/1_modify_small.jsonl\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def handler(data):\n",
    "    \"\"\"Flatten the nested ``response`` field of one record.\n",
    "\n",
    "    ``data['response']`` is indexed twice with ``[0]`` and the result is stored back under the\n",
    "    same key. The record is modified in place and returned.\n",
    "    \"\"\"\n",
    "    first_group = data['response'][0]\n",
    "    data['response'] = first_group[0]\n",
    "    return data\n",
    "\n",
    "# Apply handler to every record of the loaded dataset.\n",
    "datasetMap = data.map(handler)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "system_prompt = '''你是一位小说创作专家，你需要根据给定的要求续写文章，必须满足字数要求.'''\n",
    "\n",
    "def handler2(data):\n",
    "    \"\"\"Wrap the record's ``prompt`` in the chat template used for training.\n",
    "\n",
    "    The prompt is prefixed with a system turn carrying ``system_prompt`` and the opening of the\n",
    "    user turn, and suffixed with the end of the user turn plus the opening of the assistant\n",
    "    turn (all delimited by ``<|im_start|>`` / ``<|im_end|>`` markers). Other fields, including\n",
    "    ``response``, are left untouched. The record is modified in place and returned.\n",
    "    \"\"\"\n",
    "    prefix = f'''<|im_start|>system\\n{system_prompt}<|im_end|>\\n<|im_start|>user\\n'''\n",
    "    suffix = \"<|im_end|>\\n<|im_start|>assistant\\n\"\n",
    "    data[\"prompt\"] = prefix + data[\"prompt\"] + suffix\n",
    "    return data\n",
    "# Apply the chat-template wrapping to every record.\n",
    "datasetMap2 = datasetMap.map(handler2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_func2(example, max_length=8096):\n",
    "    \"\"\"Tokenise one prompt/response record into causal-LM training features.\n",
    "\n",
    "    Args:\n",
    "        example: mapping with the text fields ``prompt`` and ``response``.\n",
    "        max_length: maximum number of tokens kept; longer sequences are cut off at the end.\n",
    "            Defaults to 8096, the limit that was previously hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        dict with ``input_ids``, ``attention_mask`` and ``labels`` lists of equal length.\n",
    "        Prompt positions in ``labels`` are filled with -100, so only the response tokens and\n",
    "        the trailing end token keep a real label.\n",
    "    \"\"\"\n",
    "    # Both parts are tokenised without automatically added special tokens.\n",
    "    instruction = tokenizer(example['prompt'], add_special_tokens=False)\n",
    "    response = tokenizer(example['response'], add_special_tokens=False)\n",
    "\n",
    "    # NOTE(review): the sequence is closed with pad_token_id, while the prompt template closes\n",
    "    # turns with <|im_end|>; confirm this is the intended end-of-response token.\n",
    "    end_token = [tokenizer.pad_token_id]\n",
    "    input_ids = instruction['input_ids'] + response['input_ids'] + end_token\n",
    "    attention_mask = instruction['attention_mask'] + response['attention_mask'] + [1]\n",
    "    labels = [-100] * len(instruction['input_ids']) + response['input_ids'] + end_token\n",
    "\n",
    "    # Slicing is a no-op for sequences that already fit, so no explicit length check is needed.\n",
    "    return {\n",
    "        'input_ids': input_ids[:max_length],\n",
    "        'attention_mask': attention_mask[:max_length],\n",
    "        'labels': labels[:max_length],\n",
    "    }\n",
    "\n",
    "# Tokenise every record; the new feature columns are added next to the existing ones.\n",
    "datasetMap3 = datasetMap2.map(process_func2)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the dataset summary to check which columns exist after tokenisation.\n",
    "datasetMap3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    \n",
    "# Drop the columns that existed before tokenisation so only the tokenised features remain.\n",
    "# remove_columns() is the dedicated API for this; map() without a function re-processed every\n",
    "# row just to drop columns.\n",
    "datasetMap4 = datasetMap3.remove_columns(datasetMap2[\"train\"].column_names)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the labels of the first training example.\n",
    "# NOTE(review): this shows the whole list (up to the truncation length); consider slicing the output.\n",
    "datasetMap4[\"train\"][0][\"labels\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM,TrainingArguments,Trainer,DataCollatorForSeq2Seq\n",
    "from transformers import DataCollatorForLanguageModeling\n",
    "\n",
    "from datetime import datetime\n",
    "# Print the wall-clock time at which the trainer is set up.\n",
    "now = datetime.now()\n",
    "time_str = now.strftime('%Y-%m-%d %H:%M:%S')\n",
    "print(time_str)\n",
    "\n",
    "# Build the Trainer from `model`, the tokenised \"train\" split and a padding data collator.\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    train_dataset=datasetMap4[\"train\"],\n",
    "    args=TrainingArguments(\n",
    "        output_dir=\"./output/Qwen2_instruct_lora\",\n",
    "        per_device_train_batch_size=4,\n",
    "        gradient_accumulation_steps=4,\n",
    "        logging_steps=10,\n",
    "        num_train_epochs=3,\n",
    "        save_steps=100, # checkpoint interval in steps (the earlier comment mentioned a demo value of 10)\n",
    "        learning_rate=1e-4,\n",
    "        save_on_each_node=True,\n",
    "        gradient_checkpointing=True\n",
    "    ),\n",
    "    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),\n",
    "    )\n",
    " \n",
    "model.config.use_cache = False  # silence the warnings. Please re-enable for inference!\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): presumably needed so gradients can flow from the inputs when gradient_checkpointing=True\n",
    "# is combined with a mostly frozen base model; confirm against the transformers / PEFT documentation.\n",
    "model.enable_input_require_grads() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the training loop configured above.\n",
    "trainer.train()\n",
    " \n",
    "# Save the trained model into the Trainer's output_dir.\n",
    "trainer.save_model(trainer.args.output_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama_factory",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
